package com.x.mapper.interfaces;

import com.x.mapper.annotations.DtoResult;
import com.x.mapper.example.SelectDtoByExampleMapper;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultFlag;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.ResultMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import tk.mybatis.mapper.entity.EntityColumn;
import tk.mybatis.mapper.entity.EntityTable;
import tk.mybatis.mapper.mapperhelper.EntityHelper;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.concurrent.ConcurrentSkipListSet;

/**
 * 拦截selectDtoByExample方法,设置resultMap
 * @author 252944454@qq.com
 * @date 2016/12/14
 */
@Intercepts(@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}))
public class DtoInterceptor implements Interceptor {

    /**
     * 默认拦截
     */

   public final static  DtoInterceptor dtoInterceptor= new DtoInterceptor();
    static {
        dtoInterceptor.setEnable(true);
        dtoInterceptor.add(SelectDtoByExampleMapper.class);
    }




    Properties properties=new Properties();

    /**
     * 拦截的Dto
     */
    Set<DtoMeta> dtoMetas= new ConcurrentSkipListSet<>();

    boolean enable = true;

    public DtoInterceptor() {

    }
    @Override
    public Object intercept(Invocation invocation) throws Throwable {

       DtoMeta dtoMeta = getDtoMeta(invocation);

       return this.innerIntercept(invocation,dtoMeta);
    }


    protected Object innerIntercept(Invocation invocation,DtoMeta dtoMeta) throws Throwable {
        if(dtoMeta == null){
            return  invocation.proceed();
        }

        final Object[] args = invocation.getArgs();
        //获取原始的ms
        MappedStatement ms = (MappedStatement) args[0];


        //获取dto cls
        Object params = args[1];
        if (params == null) {
            return invocation.proceed();
        }
        int dtoClassIndex = dtoMeta.getArgIndex();
        Class<?> dtoCls = (Class) ((Map) params).get(""+dtoClassIndex);
        if (dtoCls == null) {
            return invocation.proceed();
        }

        //构造dto resultMap
        List<ResultMap> resultMaps = new ArrayList<>();
        ResultMap resultMap = dtoResultMap(ms,dtoCls);
        if(resultMap == null){
            return invocation.proceed();
        }
        resultMaps.add(resultMap);

        MetaObject metaObject = SystemMetaObject.forObject(ms);
        metaObject.setValue("resultMaps", resultMaps);

        //todo 可以设置回原来的值ms

        return invocation.proceed();
    }


    DtoMeta getDtoMeta(Invocation invocation){
        final Object[] args = invocation.getArgs();
        //获取原始的ms
        MappedStatement ms = (MappedStatement) args[0];

        DtoMeta dtoMeta = null;
        for(DtoMeta meta:dtoMetas){
            if(ms.getId().endsWith(meta.getMethodName())){
                dtoMeta = meta;
                break;
            }
        }

        return dtoMeta;
    }

    public Class<?> getEntityClass(MappedStatement ms){
         Class mapperClass = getMapperClass(ms.getId());

        //获取接口类型
       Type[] types= mapperClass.getGenericInterfaces();

        if(types==null||types.length==0 ){
            return null;
        }
        if(!(types[0] instanceof ParameterizedType)){
            return null;
        }
        //参数类型
        ParameterizedType type = (ParameterizedType)types[0];

        //参数实际类型参数
        types = type.getActualTypeArguments();
        if(types==null||types.length==0){
            return null;
        }

        return (Class)types[0];
    }







    /**
     * Dto 实体 mapper
     *
     * @param ms
     * @param dtoClass
     * @return
     */
    protected ResultMap dtoResultMap(MappedStatement ms, Class<?> dtoClass) {

        EntityTable entityTable = EntityHelper.getEntityTable(getEntityClass(ms));

        Configuration configuration = ms.getConfiguration();

        Set<EntityColumn> entityClassColumns = entityTable.getEntityClassColumns();
        if (entityClassColumns == null || entityClassColumns.size() == 0) {
            return null;
        }

        List<ResultMapping> resultMappings = new ArrayList<ResultMapping>();
        for (EntityColumn entityColumn : entityClassColumns) {
            ResultMapping.Builder builder = new ResultMapping.Builder(configuration, entityColumn.getProperty(), entityColumn.getColumn(), entityColumn.getJavaType());
            if (entityColumn.getJdbcType() != null) {
                builder.jdbcType(entityColumn.getJdbcType());
            }
            if (entityColumn.getTypeHandler() != null) {
                try {
                    builder.typeHandler(entityColumn.getTypeHandler().newInstance());
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            List<ResultFlag> flags = new ArrayList<ResultFlag>();
            if (entityColumn.isId()) {
                flags.add(ResultFlag.ID);
            }
            builder.flags(flags);
            resultMappings.add(builder.build());
        }
        ResultMap.Builder builder = new ResultMap.Builder(configuration, "BaseMapperResultMap", dtoClass, resultMappings, true);

        ResultMap map = builder.build();
        return map;
    }

    Class<?> getMapperClass(String msId) {
        if (msId.indexOf(".") == -1) {
            throw new RuntimeException("当前MappedStatement的id=" + msId + ",不符合MappedStatement的规则!");
        }
        String mapperClassStr = msId.substring(0, msId.lastIndexOf("."));
        try {
            return Class.forName(mapperClassStr);
        } catch (ClassNotFoundException e) {
            return null;
        }
    }




    @Override
    public Object plugin(Object target) {
        if(!this.enable || !(target instanceof Executor)){
            return target;
        }
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {

    }

    public boolean isEnable() {
        return enable;
    }

    public void setEnable(boolean enable) {
        this.enable = enable;
    }



    public DtoInterceptor add(Class cls){

        addDtoMeta(cls);

        return  this;
    }

    private void addDtoMeta(Class cls) {
        if(cls==null){
            return;
        }
        Method[]  methods =  cls.getDeclaredMethods();
        for(Method method :methods){
            proceesDtoMeta(method);
        }
    }

    private void proceesDtoMeta(Method method) {
        Annotation[][] an = null;
        an =  method.getParameterAnnotations();
        if(an.length>0){
            for(int i=0;i<an.length;i++){
                for(int j=0;j<an[i].length;j++){
                    Annotation t =an[i][j];
                   if(DtoResult.class.getTypeName().equals( t.annotationType().getName())){
                       add(new DtoMeta(method.getName(),i));
                       break;
                   }
                }
            }
        }
    }

    protected DtoInterceptor add(DtoMeta dtoMeta){
        if(dtoMeta==null){
            return this;
        }
        this.dtoMetas.add(dtoMeta);
        return this;
    }




    protected static class DtoMeta implements Comparable{
          String methodName;
          int argIndex;

       public DtoMeta(String methodName, int argIndex) {
           this.methodName = methodName;
           this.argIndex = argIndex;
       }

       public String getMethodName() {
              return methodName;
          }

          public void setMethodName(String methodName) {
              this.methodName = methodName;
          }

          public int getArgIndex() {
              return argIndex;
          }

          public void setArgIndex(int argIndex) {
              this.argIndex = argIndex;
          }

       @Override
       public boolean equals(Object o) {
           if (this == o) return true;
           if (o == null || getClass() != o.getClass()) return false;

           DtoMeta dtoMeta = (DtoMeta) o;

           return methodName.equals(dtoMeta.methodName);

       }

       @Override
       public int hashCode() {
           return methodName.hashCode();
       }

        @Override
        public int compareTo(Object o) {

            return this.getMethodName().compareTo(((DtoMeta)o).getMethodName());
        }
    }


}
